-
Notifications
You must be signed in to change notification settings - Fork 158
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Allow prior on gpu #519
Allow prior on gpu #519
Conversation
This is great, thanks! It'll also be useful e.g. when the prior is a previous posterior. I do not have a strong opinion on the assertion -- I have a slight preference for automatically moving the prior and giving a warning though. This is also what we do if the simulations lie on the GPU, so it feels natural to do the same thing for the prior. As I said, no strong opinions though. Feel free to merge if you think |
thanks @michaeldeistler regarding moving the prior to GPU, I haven't found a way for doing this actually. One would have to create a new prior instance with the parameters living on gpu. and for that one basically would need a long |
That makes sense. I think my preferred way would then be to stick to the current implementation using |
Codecov Report
@@ Coverage Diff @@
## main #519 +/- ##
==========================================
+ Coverage 67.70% 67.79% +0.08%
==========================================
Files 55 55
Lines 3970 3965 -5
==========================================
Hits 2688 2688
+ Misses 1282 1277 -5
Flags with carried forward coverage won't be shown. Click here to find out more.
Continue to review full report at Codecov.
|
@michaeldeistler I refactored to potential function to handle the devices efficiently. Do you still approve? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It looks good, thanks. I made two comments. I think replacing get_potential
with posterior_potential
would make sense, no?
1989fbd
to
5fa0770
Compare
slow tests are passing now. I had to fix small things but importantly I had to remove the |
* remove hmc with uniform prior test, too slow.
Relates to #515
Up to now we assumed the prior live on the
cpu
and we moved samples to.cpu()
whenever combininglog_probs
from theprior
and thenet
.This PR now allows the prior to live on the
GPU
, and itasserts
that the prior lives on the same device as the passeddevice
for trainingPro: we don't need to move to
.cpu()
all the timeCon:
numpy
based MCMC methods naturally happen on thecpu
. when evaluating thetheta
on the prior, we now have to move it to theprior
device (instead of doingnet.logprob(theta).cpu()
)AssertionError
when the prior was not on initialised on thedevice
. This was not the case before. Alternatively, we could introduce or deduceprior_device
and take care of moving things around internally. Any opinions on that?I haven't profiled it, but I think this way of doing it is faster than the old way because we move things less. And if we one day implement the vectorized MCMC in
torch
, we might get speed ups when running that on the GPU then.closes #515